from torch.utils.data import Dataset
from PIL import Image
import os


class Demo01Dataset(Dataset):
    """
    Learning exercise: a minimal custom ``Dataset``.

    Each sample is one regular file from ``image_dir``, opened with PIL. The
    label is currently just the image's file name; ``label_dir`` is stored but
    not read yet (see the TODO in ``__getitem__``).

    Args:
        image_dir: Directory containing the image files.
        label_dir: Directory meant to hold the labels; currently unused.
    """

    def __init__(self, image_dir, label_dir):
        self.image_dir = image_dir
        self.label_dir = label_dir
        # os.listdir returns entries in arbitrary order; sort them so a given
        # index always maps to the same file. Sub-directories are skipped
        # because they cannot be opened as images.
        self.images = sorted(
            name
            for name in os.listdir(self.image_dir)
            if os.path.isfile(os.path.join(self.image_dir, name))
        )
        # Not used by this class; kept so existing code reading ``.a`` still works.
        self.a = 1

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """
        Return ``(image, label)`` for the file at position ``index``.

        The pixel data is read fully before returning and the file handle is
        closed, so iterating the dataset does not accumulate open files.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        image_name = self.images[index]
        image_path = os.path.join(self.image_dir, image_name)
        # Image.open(path) alone is lazy and keeps the file open until the
        # pixels are first used; open the file ourselves and force load() so
        # the handle is released when the ``with`` block exits.
        with open(image_path, "rb") as image_file:
            image = Image.open(image_file)
            image.load()
        # TODO: read the real label from ``self.label_dir``; for now the file
        # name doubles as the label.
        image_label = image_name
        return image, image_label


if __name__ == '__main__':
    # Lesson training data on the author's machine; adjust to your checkout.
    train_root = "D:/Workspace/github/HelloPythonFramework/pytorch-lesson/pytorch-lesson02/data/train"

    dataset1 = Demo01Dataset(f"{train_root}/ants_image",
                             f"{train_root}/ants_label")

    print("dataset1 长度：", len(dataset1))

    # Fetch and display the first sample.
    image, label = dataset1[0]

    print(label)

    image.show()

    # A second dataset over an arbitrary directory, used only to demonstrate
    # concatenation. Its samples are not read: "D:/Workspace" is not an image
    # folder (it holds e.g. the "github" checkout directory), so indexing it
    # would fail while trying to open a non-image entry.
    dataset2 = Demo01Dataset("D:/Workspace", "D:/")

    print("dataset2 长度：", len(dataset2))

    # torch's Dataset implements ``+`` as a ConcatDataset of both operands.
    dataset3 = dataset1 + dataset2

    print("dataset3 长度：", len(dataset3))
